# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Upscale model integration."""

import io
import uuid
from typing import Tuple
from PIL import Image
from google import genai
from google.genai import types
from config.default import Default
from common.storage import store_to_gcs, download_from_gcs

cfg = Default()

UPSCALE_MODEL = "imagen-4.0-upscale-preview"

def get_image_resolution(image_data: bytes | str) -> str:
    """Gets the resolution of an image from GCS URI or bytes."""
    if isinstance(image_data, str) and image_data.startswith("gs://"):
        try:
            image_bytes = download_from_gcs(image_data)
        except Exception as e:
            print(f"Error downloading image for resolution check: {e}")
            return "Unknown"
    elif isinstance(image_data, bytes):
        image_bytes = image_data
    else:
        return "Unknown"
        
    try:
        with Image.open(io.BytesIO(image_bytes)) as img:
            return f"{img.width}x{img.height}"
    except Exception as e:
        print(f"Error getting resolution: {e}")
        return "Unknown"

def upscale_image(input_gcs_uri: str, upscale_factor: str) -> Tuple[str, str, str]:
    """
    Upscales an image using Imagen 4.0 Upscale.

    Args:
        input_gcs_uri: GCS URI of the image to upscale.
        upscale_factor: 'x2', 'x3', or 'x4'.

    Returns:
        Tuple of (output_gcs_uri, original_resolution, upscaled_resolution)
    """
    client = genai.Client(vertexai=True, project=cfg.PROJECT_ID, location=cfg.LOCATION)

    # Get original resolution
    original_resolution = get_image_resolution(input_gcs_uri)

    response = client.models.upscale_image(
        model=UPSCALE_MODEL,
        image=types.Image(gcs_uri=input_gcs_uri),
        upscale_factor=upscale_factor,
        config=types.UpscaleImageConfig(
            output_mime_type='image/png',
        ),
    )

    if not response.generated_images:
        raise RuntimeError("Upscale failed: No images generated.")

    generated_image = response.generated_images[0].image
    
    # Try to get bytes from the generated image object
    if hasattr(generated_image, 'image_bytes'):
        image_data = generated_image.image_bytes
    elif hasattr(generated_image, '_image_bytes'):
         image_data = generated_image._image_bytes
    else:
        # If we can't get bytes directly, we might need to check if it's already a PIL image
        # or if there's another way to extract it.
        # For now, assume standard SDK behavior.
        raise RuntimeError(f"Could not extract image bytes from response: {type(generated_image)}")

    # Get upscaled resolution
    upscaled_resolution = get_image_resolution(image_data)

    # Store to GCS
    file_name = f"upscaled_{uuid.uuid4()}.png"
    output_gcs_uri = store_to_gcs(
        folder="upscaled_images",
        file_name=file_name,
        mime_type="image/png",
        contents=image_data,
    )

    return output_gcs_uri, original_resolution, upscaled_resolution
